%%capture
%matplotlib inline
%load_ext autoreload
%autoreload 2
%cd ../src
As in other examples, we use a pre-trained AlexNet and append a softmax layer to its classifier.
import torchvision.models as models
from torch.nn.modules import Softmax
alexnet = models.alexnet(pretrained=True)
alexnet.classifier.add_module("softmax", Softmax(dim=1))
alexnet.eval();
Regularization is essential for max mean activation optimization to pick up the required information. In the following, we examine the effect of different regularizers.
For illustration, we optimize channel 1 of the 9th feature layer of AlexNet. First, without any regularization:
from midnite.visualization.base import *
from plot_utils import show
show(PixelActivation(
    alexnet.features[:9],
    SplitSelector(ChannelSplit(), [1]),
).visualize())
Weight decay decays the gradient during optimization, i.e. it causes less relevant parts of the optimized image to vanish.
show(PixelActivation(
    alexnet.features[:9],
    SplitSelector(ChannelSplit(), [1]),
    regularization=[WeightDecay(decay_factor=1e-3)]
).visualize())
The blur transform performs simple blurring after each iteration. However, blurring does not preserve edges.
show(PixelActivation(
    alexnet.features[:9],
    SplitSelector(ChannelSplit(), [1]),
    transform=BlurTransform()
).visualize())
The bilateral transform is very similar to blurring, but preserves edges.
show(PixelActivation(
    alexnet.features[:9],
    SplitSelector(ChannelSplit(), [1]),
    transform=BilateralTransform()
).visualize())
The random transform applies random translation, rotation, and scaling after each iteration, which makes the resulting optimized image robust to such transformations.
show(PixelActivation(
    alexnet.features[:9],
    SplitSelector(ChannelSplit(), [1]),
    transform=RandomTransform()
).visualize())
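As a conceptual sketch of one part of this (random translation), the image can be shifted by a random offset each iteration. This is not midnite's implementation, and `max_shift` is a made-up parameter:

```python
import numpy as np

rng = np.random.default_rng(0)

def random_translate(img, max_shift=2):
    # Shift the image by a random (dy, dx) offset, so the optimization
    # cannot rely on exact pixel positions.
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    return np.roll(img, shift=(dy, dx), axis=(0, 1))

img = np.arange(16.0).reshape(4, 4)
jittered = random_translate(img)
print(jittered.shape)  # shape is preserved: (4, 4)
```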
The resize transform scales the image up after each iteration. This has the advantage that low-frequency patterns can be picked up more easily. Since the image grows after each step, we use an initial size of 50 pixels (instead of the default 250).
show(PixelActivation(
    alexnet.features[:9],
    SplitSelector(ChannelSplit(), [1]),
    transform=ResizeTransform(),
    init_size=50
).visualize())
Total variation regularization adds the total variation of the image to the loss during optimization, penalizing differences between adjacent pixels.
show(PixelActivation(
    alexnet.features[:9],
    SplitSelector(ChannelSplit(), [1]),
    regularization=[TVRegularization(coefficient=5e2)]
).visualize())
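The total variation of a 2D image is the sum of absolute differences between neighboring pixels. This NumPy sketch (a simplified, single-channel version of what such a regularizer computes) makes that concrete:

```python
import numpy as np

def total_variation(img):
    # Sum of absolute differences between horizontally and
    # vertically adjacent pixels; smooth images score low.
    dh = np.abs(img[:, 1:] - img[:, :-1]).sum()
    dv = np.abs(img[1:, :] - img[:-1, :]).sum()
    return float(dh + dv)

checker = np.array([[0.0, 1.0], [1.0, 0.0]])  # maximally "jumpy" 2x2 image
flat = np.zeros((2, 2))                       # perfectly smooth image
print(total_variation(checker))  # 4.0
print(total_variation(flat))     # 0.0
```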
To obtain a better image, we combine several transformations and regularizers:
show(PixelActivation(
    alexnet.features[:9],
    SplitSelector(ChannelSplit(), [1]),
    transform=RandomTransform() + BlurTransform() + ResizeTransform(),
    regularization=[TVRegularization(), WeightDecay(decay_factor=3e-4)],
    init_size=50,
).visualize())
For our optimization, the relevant parameters are:

- `lr`: learning rate of the optimizer (default: 0.01)
- `iter_n`: number of iterations (default: 12)
- `opt_n`: number of optimization steps per iteration (default: 20)
- `init_size`: initial size of the optimized image (default: 250)

Image transformations and filters are only applied after each iteration. Apart from that, the parameters behave as in any optimization. In the following, we have a look at a few examples:
import midnite
# use 'cpu' if no GPU available
with midnite.device('cuda:0'):
    show(PixelActivation(
        alexnet.features[:9],
        SplitSelector(ChannelSplit(), [1]),
        transform=RandomTransform() + BlurTransform() + ResizeTransform(),
        regularization=[TVRegularization(), WeightDecay(decay_factor=3e-4)],
        init_size=100,
        lr=0.001,  # default: 0.01
        iter_n=5,  # default: 12
        opt_n=500,  # default: 20
    ).visualize())
show(PixelActivation(
    alexnet.features[:9],
    SplitSelector(ChannelSplit(), [1]),
    transform=RandomTransform() + BlurTransform() + ResizeTransform(scale_fac=1.05),
    regularization=[TVRegularization(), WeightDecay(decay_factor=6e-5)],
    init_size=25,
    lr=0.1,  # default: 0.01
    iter_n=50,  # default: 12
    opt_n=10,  # default: 20
).visualize())
show(PixelActivation(
    alexnet.features[:11],
    SplitSelector(NeuronSplit(), [0, 0, 0]),
    transform=BlurTransform(),
    regularization=[WeightDecay(decay_factor=1e-2)]
).visualize())
We can optimize for a target class to obtain an image of what the network understands as that class. In this example, we do this on a GoogLeNet (due to its stronger visual semantics) for ImageNet class 783 (screw).
from torch.nn import Sequential, Softmax
from plot_utils import show_normalized
googlenet = Sequential(models.googlenet(pretrained=True), Softmax(dim=1))
googlenet.eval();
with midnite.device('cuda:0'):
    show_normalized(PixelActivation(
        googlenet,
        SplitSelector(NeuronSplit(), [783]),
        transform=RandomTransform(scale_fac=0) + BilateralTransform() + ResizeTransform(1.005),
        opt_n=10,
        iter_n=300,
        regularization=[WeightDecay(decay_factor=1.5e-7), TVRegularization(coefficient=0.05)],
        init_size=300,
        lr=0.007
    ).visualize(), scale=1.5)